# Scrape the SOCR July-2009 neuroimaging dataset and keep its first HTML table.
library(rvest)
url <- "https://wiki.socr.umich.edu/index.php/SOCR_Data_July2009_ID_NI"
web_data <- read_html(url)
tables <- html_table(web_data, fill = TRUE)
df <- tables[[1]]
str(df)## tibble [672 × 13] (S3: tbl_df/tbl/data.frame)
## $ Subject_ID: int [1:672] 1 1 1 1 1 1 1 1 1 1 ...
## $ Group : chr [1:672] "AD" "AD" "AD" "AD" ...
## $ MMSE : int [1:672] 21 21 21 21 21 21 21 21 21 21 ...
## $ CDR : num [1:672] 1.2 1.2 1.2 1.2 1.2 1.2 1.2 1.2 1.2 1.2 ...
## $ Sex : logi [1:672] FALSE FALSE FALSE FALSE FALSE FALSE ...
## $ Age : int [1:672] 82 82 82 82 82 82 82 82 82 82 ...
## $ TBV : int [1:672] 1051706 1051706 1051706 1051706 1051706 1051706 1051706 1051706 1051706 1051706 ...
## $ GMV : int [1:672] 522930 522930 522930 522930 522930 522930 522930 522930 522930 522930 ...
## $ WMV : int [1:672] 247583 247583 247583 247583 247583 247583 247583 247583 247583 247583 ...
## $ CSFV : int [1:672] 281194 281194 281194 281194 281194 281194 281194 281194 281194 281194 ...
## $ ROI : int [1:672] 1 1 1 1 2 2 2 2 3 3 ...
## $ Measure : chr [1:672] "SA" "SI" "CV" "FD" ...
## $ Value : num [1:672] 14704.29 0.43 0.08 2.15 9084.68 ...
## # A tibble: 6 × 13
## Subject_ID Group MMSE CDR Sex Age TBV GMV WMV CSFV ROI
## <int> <chr> <int> <dbl> <lgl> <int> <int> <int> <int> <int> <int>
## 1 1 AD 21 1.2 FALSE 82 1051706 522930 247583 281194 1
## 2 1 AD 21 1.2 FALSE 82 1051706 522930 247583 281194 1
## 3 1 AD 21 1.2 FALSE 82 1051706 522930 247583 281194 1
## 4 1 AD 21 1.2 FALSE 82 1051706 522930 247583 281194 1
## 5 1 AD 21 1.2 FALSE 82 1051706 522930 247583 281194 2
## 6 1 AD 21 1.2 FALSE 82 1051706 522930 247583 281194 2
## # ℹ 2 more variables: Measure <chr>, Value <dbl>
# Pool the MCI and AD subjects into a single "Patients" category, then recode
# the categorical Measure and Sex columns as their numeric factor levels.
df$Group[df$Group %in% c("MCI", "AD")] <- "Patients"
df$Measure <- as.numeric(factor(df$Measure))
df$Sex <- as.numeric(factor(df$Sex))
str(df)## tibble [672 × 13] (S3: tbl_df/tbl/data.frame)
## $ Subject_ID: int [1:672] 1 1 1 1 1 1 1 1 1 1 ...
## $ Group : chr [1:672] "Patients" "Patients" "Patients" "Patients" ...
## $ MMSE : int [1:672] 21 21 21 21 21 21 21 21 21 21 ...
## $ CDR : num [1:672] 1.2 1.2 1.2 1.2 1.2 1.2 1.2 1.2 1.2 1.2 ...
## $ Sex : num [1:672] 1 1 1 1 1 1 1 1 1 1 ...
## $ Age : int [1:672] 82 82 82 82 82 82 82 82 82 82 ...
## $ TBV : int [1:672] 1051706 1051706 1051706 1051706 1051706 1051706 1051706 1051706 1051706 1051706 ...
## $ GMV : int [1:672] 522930 522930 522930 522930 522930 522930 522930 522930 522930 522930 ...
## $ WMV : int [1:672] 247583 247583 247583 247583 247583 247583 247583 247583 247583 247583 ...
## $ CSFV : int [1:672] 281194 281194 281194 281194 281194 281194 281194 281194 281194 281194 ...
## $ ROI : int [1:672] 1 1 1 1 2 2 2 2 3 3 ...
## $ Measure : num [1:672] 3 4 1 2 3 4 1 2 3 4 ...
## $ Value : num [1:672] 14704.29 0.43 0.08 2.15 9084.68 ...
## # A tibble: 6 × 13
## Subject_ID Group MMSE CDR Sex Age TBV GMV WMV CSFV ROI
## <int> <chr> <int> <dbl> <dbl> <int> <int> <int> <int> <int> <int>
## 1 1 Patients 21 1.2 1 82 1051706 522930 247583 281194 1
## 2 1 Patients 21 1.2 1 82 1051706 522930 247583 281194 1
## 3 1 Patients 21 1.2 1 82 1051706 522930 247583 281194 1
## 4 1 Patients 21 1.2 1 82 1051706 522930 247583 281194 1
## 5 1 Patients 21 1.2 1 82 1051706 522930 247583 281194 2
## 6 1 Patients 21 1.2 1 82 1051706 522930 247583 281194 2
## # ℹ 2 more variables: Measure <dbl>, Value <dbl>
This code pools the MCI and AD groups into a single "Patients" category, converts the Sex and Measure columns from categorical to numeric codes for analysis, and then displays the structure of the modified dataframe.
# One-time environment setup for the R deep-learning stack (reticulate +
# tensorflow + keras). The install lines are commented out because they only
# need to run once per machine.
#install.packages("remotes")
#remotes::install_github("rstudio/tensorflow", force = T)
#install.packages("reticulate")
library(reticulate)
# NOTE(review): machine-specific Python path — adjust for other machines, or
# prefer the commented virtualenv route below for reproducibility.
use_python("C:/Users/Jun/AppData/Local/Programs/Python/Python38/python.exe")
#py_install("Pillow", pip = TRUE)
#tensorflow::install_tensorflow()
library(tensorflow)
#tensorflow::install_tensorflow(extra_packages = "pillow")
#tensorflow::install_tensorflow(version = "2.13.*")
#virtualenv_create("r-tensorflow")
#use_virtualenv("r-tensorflow", required = TRUE)
#devtools::install_github("rstudio/keras")
library(keras)
#install_keras()
# Confirm which Python / numpy / tensorflow reticulate has bound to.
py_config()## python: C:/Users/Jun/AppData/Local/Programs/Python/Python38/python.exe
## libpython: C:/Users/Jun/AppData/Local/Programs/Python/Python38/python38.dll
## pythonhome: C:/Users/Jun/AppData/Local/Programs/Python/Python38
## version: 3.8.2 (tags/v3.8.2:7b3ab59, Feb 25 2020, 23:03:10) [MSC v.1916 64 bit (AMD64)]
## Architecture: 64bit
## numpy: C:/Users/Jun/AppData/Local/Programs/Python/Python38/Lib/site-packages/numpy
## numpy_version: 1.24.3
## tensorflow: C:\Users\Jun\AppData\Local\Programs\Python\Python38\lib\site-packages\tensorflow\__init__.p
##
## NOTE: Python version was forced by use_python() function
# Binary classification: predict the pooled diagnosis (Patients vs NC) from
# the remaining clinical/imaging features with a small keras network.
set.seed(2024) # for reproducibility
features <- df[, -c(1, 2)] # drop Subject_ID and the Group target
target <- df[, 2]
target <- ifelse(target == "Patients", 1, 0) # Binary encoding for target
# 80/20 train/test split; seq_len() avoids the 1:n zero-length footgun
indices <- sample(seq_len(nrow(features)), size = 0.8 * nrow(features), replace = FALSE)
train_x <- features[indices, ]
train_y <- target[indices, ]
test_x <- features[-indices, ]
test_y <- target[-indices, ]
# Define the model: one hidden ReLU layer (8 units) and a single sigmoid
# output unit producing the Patients probability
model <- keras_model_sequential() %>%
  layer_dense(units = 8, activation = 'relu', input_shape = c(ncol(train_x))) %>%
  layer_dense(units = 1, activation = 'sigmoid')
# Compile the model
model %>% compile(
  loss = 'binary_crossentropy',
  optimizer = 'adam',
  metrics = 'accuracy'
)
# keras expects numeric matrices, not data frames
# NOTE(review): features are unscaled (brain volumes ~1e6 vs MMSE ~20), which
# explains the enormous first-epoch loss below — consider scale() first.
train_x_matrix <- as.matrix(train_x)
train_y_matrix <- as.matrix(train_y)
history <- model %>% fit(
  train_x_matrix,
  train_y_matrix,
  epochs = 15,
  batch_size = 5,
  validation_split = 0.2
)## Epoch 1/15
## 86/86 - 1s - loss: 56296.5625 - accuracy: 0.4709 - val_loss: 667.0047 - val_accuracy: 0.7222 - 963ms/epoch - 11ms/step
## Epoch 2/15
## 86/86 - 0s - loss: 39.2112 - accuracy: 0.9371 - val_loss: 0.0000e+00 - val_accuracy: 1.0000 - 238ms/epoch - 3ms/step
## Epoch 3/15
## 86/86 - 0s - loss: 0.0000e+00 - accuracy: 1.0000 - val_loss: 0.0000e+00 - val_accuracy: 1.0000 - 194ms/epoch - 2ms/step
## Epoch 4/15
## 86/86 - 0s - loss: 0.0000e+00 - accuracy: 1.0000 - val_loss: 0.0000e+00 - val_accuracy: 1.0000 - 206ms/epoch - 2ms/step
## Epoch 5/15
## 86/86 - 0s - loss: 0.0000e+00 - accuracy: 1.0000 - val_loss: 0.0000e+00 - val_accuracy: 1.0000 - 208ms/epoch - 2ms/step
## Epoch 6/15
## 86/86 - 0s - loss: 0.0000e+00 - accuracy: 1.0000 - val_loss: 0.0000e+00 - val_accuracy: 1.0000 - 209ms/epoch - 2ms/step
## Epoch 7/15
## 86/86 - 0s - loss: 0.0000e+00 - accuracy: 1.0000 - val_loss: 0.0000e+00 - val_accuracy: 1.0000 - 221ms/epoch - 3ms/step
## Epoch 8/15
## 86/86 - 0s - loss: 0.0000e+00 - accuracy: 1.0000 - val_loss: 0.0000e+00 - val_accuracy: 1.0000 - 198ms/epoch - 2ms/step
## Epoch 9/15
## 86/86 - 0s - loss: 0.0000e+00 - accuracy: 1.0000 - val_loss: 0.0000e+00 - val_accuracy: 1.0000 - 191ms/epoch - 2ms/step
## Epoch 10/15
## 86/86 - 0s - loss: 0.0000e+00 - accuracy: 1.0000 - val_loss: 0.0000e+00 - val_accuracy: 1.0000 - 184ms/epoch - 2ms/step
## Epoch 11/15
## 86/86 - 0s - loss: 0.0000e+00 - accuracy: 1.0000 - val_loss: 0.0000e+00 - val_accuracy: 1.0000 - 185ms/epoch - 2ms/step
## Epoch 12/15
## 86/86 - 0s - loss: 0.0000e+00 - accuracy: 1.0000 - val_loss: 0.0000e+00 - val_accuracy: 1.0000 - 200ms/epoch - 2ms/step
## Epoch 13/15
## 86/86 - 0s - loss: 0.0000e+00 - accuracy: 1.0000 - val_loss: 0.0000e+00 - val_accuracy: 1.0000 - 190ms/epoch - 2ms/step
## Epoch 14/15
## 86/86 - 0s - loss: 0.0000e+00 - accuracy: 1.0000 - val_loss: 0.0000e+00 - val_accuracy: 1.0000 - 192ms/epoch - 2ms/step
## Epoch 15/15
## 86/86 - 0s - loss: 0.0000e+00 - accuracy: 1.0000 - val_loss: 0.0000e+00 - val_accuracy: 1.0000 - 192ms/epoch - 2ms/step
## 5/5 - 0s - 77ms/epoch - 15ms/step
# Predict on the held-out test set. This predict() call was missing from the
# code (the "5/5 - ..." output above comes from it); restored so that
# `predictions` is defined before it is used.
predictions <- model %>% predict(as.matrix(test_x))
# Threshold the sigmoid probabilities at 0.5
predicted_classes <- ifelse(predictions > 0.5, 1, 0)
# Evaluation metrics
confusion_matrix <- table(Predicted = predicted_classes, Actual = as.matrix(test_y))
accuracy <- sum(diag(confusion_matrix)) / sum(confusion_matrix)
sensitivity <- confusion_matrix["1","1"] / sum(confusion_matrix[, "1"]) # TP / (TP + FN)
specificity <- confusion_matrix["0","0"] / sum(confusion_matrix[, "0"]) # TN / (TN + FP)
# Diagnostic odds ratio = (TP/FN) * (TN/FP) = TP*TN / (FN*FP).
# The original DIVIDED the two odds, which computes TP*FP/(FN*TN) instead.
odds_ratio <- (sensitivity / (1 - sensitivity)) * (specificity / (1 - specificity))
# Guard log() against the degenerate 0 / Inf odds-ratio cases
LOR <- ifelse(odds_ratio == 0, -Inf, ifelse(odds_ratio == Inf, Inf, log(odds_ratio)))
auc <- pROC::auc(pROC::roc(response = test_y, predictor = as.numeric(predictions)))## Setting levels: control = 0, case = 1
## Setting direction: controls < cases
# Display results
list(
  ConfusionMatrix = confusion_matrix,
  Accuracy = accuracy,
  Sensitivity = sensitivity,
  Specificity = specificity,
  OddsRatio = odds_ratio,
  LOR = LOR,
  AUC = auc
)## $ConfusionMatrix
## Actual
## Predicted 0 1
## 0 46 0
## 1 0 89
##
## $Accuracy
## [1] 1
##
## $Sensitivity
## [1] 1
##
## $Specificity
## [1] 1
##
## $OddsRatio
## [1] NaN
##
## $LOR
## [1] NA
##
## $AUC
## Area under the curve: 1
When comparing the results of the binary classification to the multi-class classification under the same conditions (15 epochs and a batch size of 5), the perfect metrics (100% accuracy and an AUC of 1) that the binary classification model reaches by epoch 3 might indeed raise concerns about overfitting — or indicate that the test set was not challenging enough, or that the data does not reflect real-world complexity.
Binary Classification Model Results Confusion Matrix: The matrix shows that the model predicted all the instances correctly. There are no false positives or false negatives.
Accuracy: The accuracy is 1 (or 100%), which means the model correctly classified all the instances in the test dataset.
Sensitivity: Also known as recall or true positive rate, is 1, indicating that the model identified all positive instances correctly.
Specificity: This is the true negative rate, and it is also 1, meaning all negative instances were correctly identified by the model.
Odds Ratio: This is NaN because, with no false positives or false negatives, sensitivity and specificity are both 1; the component odds each involve division by zero (yielding Inf), and the resulting Inf/Inf is indeterminate.
Log Odds Ratio (LOR): This is NA, which is a consequence of the odds ratio being NaN. Without false positives or false negatives, the odds ratio cannot be computed.
Area Under the Curve (AUC): With a value of 1, this indicates perfect discrimination by the model between the positive and negative classes.
## Warning: package 'ggplot2' was built under R version 4.2.3
# Exploratory plots. ggplot2 is loaded here, before its first use (the
# package-version warning above comes from this load).
library(ggplot2)
# 1. Scatter plot of Group vs Age
ggplot(df, aes(x = Age, y = Group)) +
  geom_point() +
  labs(title = "Scatter Plot of Group vs Age", x = "Age", y = "Group")
# 2. Boxplot of Age across the different Groups
ggplot(df, aes(x = Group, y = Age)) +
  geom_boxplot() +
  labs(title = "Boxplot of Age for Different Groups", x = "Group", y = "Age")
# 3. Histogram of Age
ggplot(df, aes(x = Age)) +
  geom_histogram(bins = 30) +
  labs(title = "Histogram of Age", x = "Age", y = "Count")
# 4. Density plot of Age by Group
ggplot(df, aes(x = Age, fill = Group)) +
  geom_density(alpha = 0.7) +
  labs(title = "Density Plot of Age for Different Groups", x = "Age", y = "Density")
# 5. Bar plot for the count of different Groups
ggplot(df, aes(x = Group)) +
  geom_bar() +
  labs(title = "Count of Different Groups", x = "Group", y = "Count")
# Bar plot of the model's predicted classes.
# NOTE(review): the original tabulated `predictions` — raw sigmoid
# probabilities — which would produce one bar per distinct probability.
# Tabulating the thresholded `predicted_classes` gives the intended 0/1 plot.
prediction_table <- table(predicted_classes)
# Convert the table into a data frame for plotting with ggplot2
prediction_df <- as.data.frame(prediction_table)
names(prediction_df) <- c("Class", "Count")
ggplot(prediction_df, aes(x = Class, y = Count, fill = Class)) +
  geom_bar(stat = "identity") + # "identity": bar heights come from Count
  theme_minimal() +
  labs(title = "Bar Graph of Predicted Classes", x = "Class", y = "Count") +
  scale_fill_brewer(palette = "Set2") + # distinct color per class
  scale_x_discrete(labels = c("0" = "NC", "1" = "Patients")) ## <pointer: 0x0>
The data contains 448 records classified as patients, compared to 224 classified as normal controls (NC). Within the predicted classifications, there are 128 individuals identified as patients, in contrast to 74 as non-cases (NC).
The model summary shows a sequential neural network with two layers:
dense_1: A fully connected layer with 8 neurons that takes 11 input features, giving 96 parameters (11 × 8 weights + 8 biases).
dense: The output layer with a single neuron, suitable for binary classification, with 9 parameters (8 × 1 weights + 1 bias).
The network has a total of 105 trainable parameters, indicating it’s a relatively simple model that will learn from the data during training. The activation functions used suggest the model is designed to capture non-linear patterns and output probabilities for binary classification.
# Multi-class classification: re-extract the raw table (web_data still holds
# the first page) and predict the original three-level Group (AD / MCI / NC).
set.seed(2024)
df2 <- web_data %>% html_table(fill = TRUE) %>% .[[1]]
df2$Group <- as.numeric(as.factor(df2$Group)) - 1 # classes coded 0, 1, 2
df2$Measure <- as.numeric(as.factor(df2$Measure))
df2$Sex <- as.numeric(as.factor(df2$Sex))
features <- df2[, -c(1, 2)] # drop Subject_ID and the Group target
target <- df2[, 2]
# 80/20 train/test split; seq_len() avoids the 1:n zero-length footgun
indices <- sample(seq_len(nrow(features)), size = 0.8 * nrow(features), replace = FALSE)
train_x <- features[indices, ]
train_y <- target[indices, ]
test_x <- features[-indices, ]
test_y <- target[-indices, ]
# Define the model for multi-class classification: one softmax unit per class
model <- keras_model_sequential() %>%
  layer_dense(units = 8, activation = 'relu', input_shape = c(ncol(train_x))) %>%
  layer_dense(units = 3, activation = 'softmax') # Adjust for three classes
# Compile the model for multi-class classification
model %>% compile(
  loss = 'categorical_crossentropy', # categorical loss for one-hot targets
  optimizer = 'adam',
  metrics = 'accuracy'
)
# Convert 'train_y' and 'test_y' to one-hot encoded format
train_y_onehot <- to_categorical(as.matrix(train_y))
test_y_onehot <- to_categorical(as.matrix(test_y))
# Fit the model on one-hot encoded targets
history <- model %>% fit(
  as.matrix(train_x),
  train_y_onehot,
  epochs = 15,
  batch_size = 5,
  validation_split = 0.2
)## Epoch 1/15
## 86/86 - 1s - loss: 621639.9375 - accuracy: 0.3333 - val_loss: 463797.7812 - val_accuracy: 0.3704 - 829ms/epoch - 10ms/step
## Epoch 2/15
## 86/86 - 0s - loss: 375609.9688 - accuracy: 0.3333 - val_loss: 260615.2344 - val_accuracy: 0.3704 - 219ms/epoch - 3ms/step
## Epoch 3/15
## 86/86 - 0s - loss: 162118.6562 - accuracy: 0.3333 - val_loss: 29400.5801 - val_accuracy: 0.3704 - 237ms/epoch - 3ms/step
## Epoch 4/15
## 86/86 - 0s - loss: 9215.8760 - accuracy: 0.2051 - val_loss: 5075.2012 - val_accuracy: 0.3704 - 194ms/epoch - 2ms/step
## Epoch 5/15
## 86/86 - 0s - loss: 3057.4426 - accuracy: 0.2308 - val_loss: 3777.0122 - val_accuracy: 0.6481 - 192ms/epoch - 2ms/step
## Epoch 6/15
## 86/86 - 0s - loss: 2795.2437 - accuracy: 0.2098 - val_loss: 2641.6807 - val_accuracy: 0.0000e+00 - 222ms/epoch - 3ms/step
## Epoch 7/15
## 86/86 - 0s - loss: 2609.0393 - accuracy: 0.1562 - val_loss: 2646.5647 - val_accuracy: 0.2963 - 204ms/epoch - 2ms/step
## Epoch 8/15
## 86/86 - 0s - loss: 2232.0449 - accuracy: 0.2051 - val_loss: 2059.3772 - val_accuracy: 0.3704 - 191ms/epoch - 2ms/step
## Epoch 9/15
## 86/86 - 0s - loss: 2008.0107 - accuracy: 0.1981 - val_loss: 2067.1299 - val_accuracy: 0.2870 - 195ms/epoch - 2ms/step
## Epoch 10/15
## 86/86 - 0s - loss: 1645.1154 - accuracy: 0.2168 - val_loss: 1322.6821 - val_accuracy: 0.0093 - 237ms/epoch - 3ms/step
## Epoch 11/15
## 86/86 - 0s - loss: 1513.8796 - accuracy: 0.2331 - val_loss: 1268.5308 - val_accuracy: 0.2778 - 191ms/epoch - 2ms/step
## Epoch 12/15
## 86/86 - 0s - loss: 1283.9553 - accuracy: 0.2471 - val_loss: 1954.7231 - val_accuracy: 0.3704 - 200ms/epoch - 2ms/step
## Epoch 13/15
## 86/86 - 0s - loss: 1103.4797 - accuracy: 0.3473 - val_loss: 979.8542 - val_accuracy: 0.3704 - 209ms/epoch - 2ms/step
## Epoch 14/15
## 86/86 - 0s - loss: 654.7548 - accuracy: 0.2984 - val_loss: 710.6323 - val_accuracy: 0.6019 - 238ms/epoch - 3ms/step
## Epoch 15/15
## 86/86 - 0s - loss: 454.2026 - accuracy: 0.4709 - val_loss: 174.0480 - val_accuracy: 0.3704 - 203ms/epoch - 2ms/step
# Predict on the test set with the multi-class model
predictions <- model %>% predict(as.matrix(test_x))## 5/5 - 0s - 49ms/epoch - 10ms/step
# Convert each softmax row to a 0-based class label (index of max probability)
predicted_classes <- apply(predictions, 1, which.max) - 1
# Confusion matrix of predicted vs actual classes
confusion_matrix <- table(Predicted = predicted_classes, Actual = as.matrix(test_y))
# Accuracy computed directly from the label vectors; this is safer than
# sum(diag(confusion_matrix)) because the table is not square when a class is
# never predicted (as happens here with class 2).
accuracy <- mean(predicted_classes == as.vector(as.matrix(test_y)))
# One-vs-rest AUC per class; vapply pins the return type, and seq_len() is
# safe for degenerate inputs
auc_values <- vapply(seq_len(ncol(test_y_onehot)), function(class_index) {
  roc_obj <- pROC::roc(response = test_y_onehot[, class_index], predictor = predictions[, class_index])
  as.numeric(pROC::auc(roc_obj))
}, numeric(1))## Setting levels: control = 0, case = 1
## Setting direction: controls < cases
## Setting levels: control = 0, case = 1
## Setting direction: controls > cases
## Setting levels: control = 0, case = 1
## Setting direction: controls > cases
# Macro-average AUC across the three one-vs-rest curves
mean_auc <- mean(auc_values)
# Display results
list(
  ConfusionMatrix = confusion_matrix,
  Accuracy = accuracy,
  AUC = mean_auc
)## $ConfusionMatrix
## Actual
## Predicted 0 1 2
## 0 41 48 0
## 1 0 0 46
##
## $Accuracy
## [1] 0.3037037
##
## $AUC
## [1] 0.6957535
## Model: "sequential_1"
## ________________________________________________________________________________
## Layer (type) Output Shape Param #
## ================================================================================
## dense_3 (Dense) (None, 8) 96
## dense_2 (Dense) (None, 3) 27
## ================================================================================
## Total params: 123
## Trainable params: 123
## Non-trainable params: 0
## ________________________________________________________________________________
Using the same training parameters ensures that the difference in performance is likely due to the increased complexity of the task rather than differences in the training process itself. As more classes are introduced, the model must learn to distinguish between a greater number of patterns, which typically makes the task more challenging and can lead to a decrease in performance metrics compared to a binary classification task.
Multi-Class Classification Model Results Confusion Matrix:
This is a 3x3 matrix, indicating three classes.
For ‘class 0’, the model predicted correctly 41 times, and no actual class-0 instances were predicted as ‘class 1’ or ‘class 2’.
For ‘class 1’, there are no correct predictions: all 48 instances whose actual class was ‘class 1’ were incorrectly predicted as ‘class 0’.
For ‘class 2’, there are also no correct predictions: all 46 instances whose actual class was ‘class 2’ were incorrectly predicted as ‘class 1’.
The absence of a ‘class 2’ row in the Predicted margin indicates that the model never predicted ‘class 2’ for any test instance.
Accuracy: The accuracy of the model is about 30.37% which is not perfect as in the binary case. This means that only 30.37% of the instances were correctly classified.
Area Under the Curve (AUC): The AUC value is 0.6957535, which is relatively higher than the accuracy and generally represents a good level of separability across the three classes.
The first dense layer, dense_3, has 8 neurons and is connected to the input layer. Since the Param # is 96, this implies the input layer has 11 features (11 × 8 weights + 8 biases = 96).
The second dense layer, dense_2, has 3 neurons, corresponding to the 3 classes of the output with a total of 27 parameters indicate that it takes inputs from the 8 neurons of the previous layer (8 x 3 + 3).
In the context where the same number of epochs, batch size, and seed are used for training both the binary and multi-class classification models, it’s notable that the multi-class model performed inferiorly, with an accuracy of about 30.37% and an AUC of approximately 0.696. This is an expected outcome, as the complexity of the task tends to increase with the addition of more classes.
## Warning: package 'dplyr' was built under R version 4.2.3
##
## Attaching package: 'dplyr'
## The following objects are masked from 'package:stats':
##
## filter, lag
## The following objects are masked from 'package:base':
##
## intersect, setdiff, setequal, union
# Scrape the SOCR allometric plant-relations dataset; keep the first table.
url <- "https://wiki.socr.umich.edu/index.php/SOCR_Data_Dinov_032708_AllometricPlanRels"
web_data <- read_html(url)
tables <- html_table(web_data, fill = TRUE)
df <- tables[[1]]
head(df)## # A tibble: 6 × 8
## `Province/Sites` `Alt.(m)` `Long.(E,deg.)` `Lat.(N,deg.)` Born `L(g/no.)`
## <chr> <int> <dbl> <dbl> <chr> <dbl>
## 1 Heilongjiang 800 129. 44.3 natural 17538.
## 2 Heilongjiang 550 125. 52.3 natural 9313.
## 3 Heilongjiang 441 127. 51.7 natural 2570.
## 4 Heilongjiang 590 132. 46.5 natural 13939.
## 5 Heilongjiang 800 130. 44.1 natural 14375
## 6 Heilongjiang 590 125. 51.4 natural 9017.
## # ℹ 2 more variables: `M(g/no.)` <dbl>, `D(no./m2)` <dbl>
# Preprocess: recode every character column as its numeric factor level.
# (The original did this in two passes — character -> factor, then
# factor -> numeric — which is equivalent to a single pass here, since the
# scraped table contains no pre-existing factor columns.)
df <- df %>%
  mutate(across(where(is.character), ~ as.numeric(factor(.x))))
# Verify the structure
str(df)## tibble [48 × 8] (S3: tbl_df/tbl/data.frame)
## $ Province/Sites: num [1:48] 1 1 1 1 1 1 1 1 2 2 ...
## $ Alt.(m) : int [1:48] 800 550 441 590 800 590 876 500 880 900 ...
## $ Long.(E,deg.) : num [1:48] 129 125 127 132 130 ...
## $ Lat.(N,deg.) : num [1:48] 44.3 52.3 51.7 46.5 44.1 ...
## $ Born : num [1:48] 1 1 1 1 1 1 1 1 1 1 ...
## $ L(g/no.) : num [1:48] 17538 9313 2570 13939 14375 ...
## $ M(g/no.) : num [1:48] 610990 298385 82175 422030 450643 ...
## $ D(no./m2) : num [1:48] 0.0394 0.0291 0.114 0.033 0.0544 ...
# Rename the scraped columns to syntactic, self-describing names.
# (The no-op `Born = Born` entry from the original is dropped.)
df <- df %>%
  rename(
    PS = `Province/Sites`,
    Altitude = `Alt.(m)`,
    Longitude = `Long.(E,deg.)`,
    Latitude = `Lat.(N,deg.)`,
    Length = `L(g/no.)`,
    Mass = `M(g/no.)`,
    Density = `D(no./m2)`
  )# Split the data into training and testing sets
set.seed(123) # for reproducibility
# 70/30 train/test split; seq_len() avoids the 1:n zero-length footgun
train_indices <- sample(seq_len(nrow(df)), size = 0.7 * nrow(df))
train_data <- df[train_indices, ]
test_data <- df[-train_indices, ]
#install.packages("neuralnet")
library(neuralnet)
# Fit a feed-forward network predicting Density from all other columns, with
# two hidden layers of 7 and 2 neurons respectively.
# NOTE(review): Density is a continuous response, but linear.output = FALSE
# applies the logistic activation to the output node, squashing predictions
# into (0, 1) — for regression, linear.output = TRUE is usually intended.
# Inputs are also unscaled, which neuralnet is sensitive to; confirm the fit.
# Define the model
model <- neuralnet(Density ~ . , data = train_data, hidden = c(7,2), linear.output = FALSE)
plot(model,rep = "best")The neural network is configured to predict Density using all available features in the train_data. It has a two-layered hidden structure with 7 neurons in the first layer and 2 neurons in the second layer. It does not output linear values, suggesting its use for classification. The error rate of 0.87 suggests the accuracy or fit of the best model (out of possibly many repetitions) to the training data. Whether this error rate is acceptable depends on the domain and the specific task. The training process took 217 steps to converge to the best solution, which suggests the complexity of finding the optimal weights for the given architecture and data.
# Now separate the X and Y variables for the keras regression model
features <- df %>%
  select(-Density) %>%          # exclude the response variable
  select(where(is.numeric)) %>% # keras needs an all-numeric matrix
  as.matrix()
response <- df$Density
# 70/30 split; note this resamples, so it differs from the neuralnet split
train_indices <- sample(seq_len(nrow(features)), size = 0.7 * nrow(features))
train_features <- features[train_indices, ]
train_response <- response[train_indices]
test_features <- features[-train_indices, ]
test_response <- response[-train_indices]
# Define the model: three shrinking ReLU layers and a single linear output
# unit for regression
model <- keras_model_sequential() %>%
  layer_dense(units = 256, activation = "relu") %>%
  layer_dense(units = 128, activation = "relu") %>%
  layer_dense(units = 64, activation = "relu") %>%
  # layer_dense(units = 16, activation = "relu") %>%
  layer_dense(units = 1)
# Compile for regression: MSE loss, MAE tracked for interpretability
model %>% compile(
  loss = 'mse',
  optimizer = optimizer_rmsprop(),
  metrics = list("mean_absolute_error"))
# NOTE(review): features span wildly different scales (Mass ~1e5 vs
# Density < 1), which explains the unstable losses below — consider scale().
# Fit the model on the training data
history <- model %>% fit(
  as.matrix(train_features), as.matrix(train_response),
  epochs = 200,
  batch_size = 10,
  validation_split = 0.2
)## Epoch 1/200
## 3/3 - 1s - loss: 2018269440.0000 - mean_absolute_error: 22177.6465 - val_loss: 19320790.0000 - val_mean_absolute_error: 3743.8501 - 731ms/epoch - 244ms/step
## Epoch 2/200
## 3/3 - 0s - loss: 81940584.0000 - mean_absolute_error: 5781.9619 - val_loss: 541667.3750 - val_mean_absolute_error: 612.5004 - 30ms/epoch - 10ms/step
## Epoch 3/200
## 3/3 - 0s - loss: 420525.3750 - mean_absolute_error: 378.3483 - val_loss: 16353.6895 - val_mean_absolute_error: 112.0617 - 31ms/epoch - 10ms/step
## Epoch 4/200
## 3/3 - 0s - loss: 53686.3828 - mean_absolute_error: 158.2263 - val_loss: 4885.2241 - val_mean_absolute_error: 60.9166 - 32ms/epoch - 11ms/step
## Epoch 5/200
## 3/3 - 0s - loss: 58311.3984 - mean_absolute_error: 96.0378 - val_loss: 1108321.8750 - val_mean_absolute_error: 909.3096 - 31ms/epoch - 10ms/step
## Epoch 6/200
## 3/3 - 0s - loss: 538943.4375 - mean_absolute_error: 429.0291 - val_loss: 279262.1562 - val_mean_absolute_error: 460.6824 - 32ms/epoch - 11ms/step
## Epoch 7/200
## 3/3 - 0s - loss: 136400.3281 - mean_absolute_error: 203.0796 - val_loss: 1333412.2500 - val_mean_absolute_error: 995.8418 - 30ms/epoch - 10ms/step
## Epoch 8/200
## 3/3 - 0s - loss: 27682180.0000 - mean_absolute_error: 3237.5864 - val_loss: 17711674.0000 - val_mean_absolute_error: 3588.7964 - 32ms/epoch - 11ms/step
## Epoch 9/200
## 3/3 - 0s - loss: 10073023.0000 - mean_absolute_error: 1971.4135 - val_loss: 378269.7188 - val_mean_absolute_error: 526.8839 - 32ms/epoch - 11ms/step
## Epoch 10/200
## 3/3 - 0s - loss: 6927231.5000 - mean_absolute_error: 1634.7599 - val_loss: 212869.8594 - val_mean_absolute_error: 390.4831 - 33ms/epoch - 11ms/step
## Epoch 11/200
## 3/3 - 0s - loss: 13129225.0000 - mean_absolute_error: 1915.7727 - val_loss: 117424000.0000 - val_mean_absolute_error: 9245.2441 - 34ms/epoch - 11ms/step
## Epoch 12/200
## 3/3 - 0s - loss: 558812672.0000 - mean_absolute_error: 13993.5732 - val_loss: 542833.8125 - val_mean_absolute_error: 650.6257 - 39ms/epoch - 13ms/step
## Epoch 13/200
## 3/3 - 0s - loss: 1684837.2500 - mean_absolute_error: 796.0423 - val_loss: 9776.2256 - val_mean_absolute_error: 81.7043 - 31ms/epoch - 10ms/step
## Epoch 14/200
## 3/3 - 0s - loss: 28271.7773 - mean_absolute_error: 111.4699 - val_loss: 7318.7710 - val_mean_absolute_error: 81.0536 - 32ms/epoch - 11ms/step
## Epoch 15/200
## 3/3 - 0s - loss: 6255.3887 - mean_absolute_error: 65.2097 - val_loss: 19603.3438 - val_mean_absolute_error: 133.5465 - 36ms/epoch - 12ms/step
## Epoch 16/200
## 3/3 - 0s - loss: 83476.8203 - mean_absolute_error: 187.1333 - val_loss: 2539.5422 - val_mean_absolute_error: 43.4944 - 36ms/epoch - 12ms/step
## Epoch 17/200
## 3/3 - 0s - loss: 2667.6514 - mean_absolute_error: 44.1748 - val_loss: 2941.7185 - val_mean_absolute_error: 46.2272 - 35ms/epoch - 12ms/step
## Epoch 18/200
## 3/3 - 0s - loss: 7185.3838 - mean_absolute_error: 60.9618 - val_loss: 1897.5573 - val_mean_absolute_error: 35.5266 - 34ms/epoch - 11ms/step
## Epoch 19/200
## 3/3 - 0s - loss: 27400.7930 - mean_absolute_error: 117.2747 - val_loss: 2352.3967 - val_mean_absolute_error: 39.3390 - 34ms/epoch - 11ms/step
## Epoch 20/200
## 3/3 - 0s - loss: 33816.3203 - mean_absolute_error: 103.3453 - val_loss: 538546.8750 - val_mean_absolute_error: 610.3983 - 32ms/epoch - 11ms/step
## Epoch 21/200
## 3/3 - 0s - loss: 33267608.0000 - mean_absolute_error: 3490.0210 - val_loss: 156138608.0000 - val_mean_absolute_error: 10675.3682 - 34ms/epoch - 11ms/step
## Epoch 22/200
## 3/3 - 0s - loss: 137009712.0000 - mean_absolute_error: 7531.5337 - val_loss: 271546624.0000 - val_mean_absolute_error: 14046.1436 - 33ms/epoch - 11ms/step
## Epoch 23/200
## 3/3 - 0s - loss: 151299984.0000 - mean_absolute_error: 8139.8408 - val_loss: 43819368.0000 - val_mean_absolute_error: 5633.1338 - 34ms/epoch - 11ms/step
## Epoch 24/200
## 3/3 - 0s - loss: 80227248.0000 - mean_absolute_error: 6075.7715 - val_loss: 421272.5000 - val_mean_absolute_error: 536.9413 - 34ms/epoch - 11ms/step
## Epoch 25/200
## 3/3 - 0s - loss: 339061.2188 - mean_absolute_error: 423.4252 - val_loss: 68973.0469 - val_mean_absolute_error: 205.8063 - 32ms/epoch - 11ms/step
## Epoch 26/200
## 3/3 - 0s - loss: 53296.1719 - mean_absolute_error: 170.2190 - val_loss: 1788.9784 - val_mean_absolute_error: 37.8697 - 33ms/epoch - 11ms/step
## Epoch 27/200
## 3/3 - 0s - loss: 3559.3306 - mean_absolute_error: 41.7822 - val_loss: 1289.1930 - val_mean_absolute_error: 33.4098 - 34ms/epoch - 11ms/step
## Epoch 28/200
## 3/3 - 0s - loss: 8887.9316 - mean_absolute_error: 50.1870 - val_loss: 154476.7500 - val_mean_absolute_error: 317.8371 - 33ms/epoch - 11ms/step
## Epoch 29/200
## 3/3 - 0s - loss: 63265.7461 - mean_absolute_error: 145.9257 - val_loss: 60971.0820 - val_mean_absolute_error: 224.6793 - 33ms/epoch - 11ms/step
## Epoch 30/200
## 3/3 - 0s - loss: 153932.9375 - mean_absolute_error: 272.0783 - val_loss: 2402967.2500 - val_mean_absolute_error: 1305.9403 - 33ms/epoch - 11ms/step
## Epoch 31/200
## 3/3 - 0s - loss: 13533996.0000 - mean_absolute_error: 2225.2485 - val_loss: 7977262.5000 - val_mean_absolute_error: 2427.6316 - 34ms/epoch - 11ms/step
## Epoch 32/200
## 3/3 - 0s - loss: 131131312.0000 - mean_absolute_error: 7583.3003 - val_loss: 24936084.0000 - val_mean_absolute_error: 4279.2773 - 36ms/epoch - 12ms/step
## Epoch 33/200
## 3/3 - 0s - loss: 82842440.0000 - mean_absolute_error: 5778.0786 - val_loss: 171484.3906 - val_mean_absolute_error: 369.3508 - 33ms/epoch - 11ms/step
## Epoch 34/200
## 3/3 - 0s - loss: 167384.1875 - mean_absolute_error: 303.4387 - val_loss: 78062.3672 - val_mean_absolute_error: 218.7873 - 43ms/epoch - 14ms/step
## Epoch 35/200
## 3/3 - 0s - loss: 239421.1406 - mean_absolute_error: 313.0955 - val_loss: 1579.3732 - val_mean_absolute_error: 35.3739 - 33ms/epoch - 11ms/step
## Epoch 36/200
## 3/3 - 0s - loss: 2630.9927 - mean_absolute_error: 41.4062 - val_loss: 2226.5078 - val_mean_absolute_error: 41.4763 - 34ms/epoch - 11ms/step
## Epoch 37/200
## 3/3 - 0s - loss: 2470.5667 - mean_absolute_error: 41.6119 - val_loss: 1694.5486 - val_mean_absolute_error: 36.7925 - 32ms/epoch - 11ms/step
## Epoch 38/200
## 3/3 - 0s - loss: 13412.1826 - mean_absolute_error: 76.4950 - val_loss: 10037.0137 - val_mean_absolute_error: 72.8492 - 32ms/epoch - 11ms/step
## Epoch 39/200
## 3/3 - 0s - loss: 59829.0391 - mean_absolute_error: 145.3856 - val_loss: 3701.6887 - val_mean_absolute_error: 58.1323 - 32ms/epoch - 11ms/step
## Epoch 40/200
## 3/3 - 0s - loss: 9159.4854 - mean_absolute_error: 60.3409 - val_loss: 540286.5625 - val_mean_absolute_error: 612.2853 - 31ms/epoch - 10ms/step
## Epoch 41/200
## 3/3 - 0s - loss: 26429024.0000 - mean_absolute_error: 2599.9722 - val_loss: 175104480.0000 - val_mean_absolute_error: 11307.8828 - 32ms/epoch - 11ms/step
## Epoch 42/200
## 3/3 - 0s - loss: 335284384.0000 - mean_absolute_error: 9616.6221 - val_loss: 1261259.3750 - val_mean_absolute_error: 939.9443 - 31ms/epoch - 10ms/step
## Epoch 43/200
## 3/3 - 0s - loss: 1009645.6875 - mean_absolute_error: 652.6781 - val_loss: 11740.5420 - val_mean_absolute_error: 103.9871 - 31ms/epoch - 10ms/step
## Epoch 44/200
## 3/3 - 0s - loss: 9178.7012 - mean_absolute_error: 69.1061 - val_loss: 2743.8745 - val_mean_absolute_error: 47.3708 - 33ms/epoch - 11ms/step
## Epoch 45/200
## 3/3 - 0s - loss: 1828.2194 - mean_absolute_error: 38.1877 - val_loss: 1964.4827 - val_mean_absolute_error: 38.1303 - 32ms/epoch - 11ms/step
## Epoch 46/200
## 3/3 - 0s - loss: 3021.6558 - mean_absolute_error: 42.4684 - val_loss: 1528.6154 - val_mean_absolute_error: 34.9692 - 36ms/epoch - 12ms/step
## Epoch 47/200
## 3/3 - 0s - loss: 2329.2778 - mean_absolute_error: 38.0774 - val_loss: 36188.1445 - val_mean_absolute_error: 150.5165 - 30ms/epoch - 10ms/step
## Epoch 48/200
## 3/3 - 0s - loss: 537086.6250 - mean_absolute_error: 508.8421 - val_loss: 646195.9375 - val_mean_absolute_error: 701.8254 - 30ms/epoch - 10ms/step
## Epoch 49/200
## 3/3 - 0s - loss: 9328359.0000 - mean_absolute_error: 2050.6536 - val_loss: 8073240.0000 - val_mean_absolute_error: 2441.4158 - 33ms/epoch - 11ms/step
## Epoch 50/200
## 3/3 - 0s - loss: 13089812.0000 - mean_absolute_error: 2482.0381 - val_loss: 218663.0781 - val_mean_absolute_error: 384.3835 - 32ms/epoch - 11ms/step
## Epoch 51/200
## 3/3 - 0s - loss: 2962952.5000 - mean_absolute_error: 1125.8254 - val_loss: 513172.1562 - val_mean_absolute_error: 596.6017 - 34ms/epoch - 11ms/step
## Epoch 52/200
## 3/3 - 0s - loss: 3145223.5000 - mean_absolute_error: 1027.4364 - val_loss: 24655.3945 - val_mean_absolute_error: 115.7699 - 45ms/epoch - 15ms/step
## Epoch 53/200
## 3/3 - 0s - loss: 46877.7539 - mean_absolute_error: 131.2449 - val_loss: 358488.5625 - val_mean_absolute_error: 526.1915 - 32ms/epoch - 11ms/step
## Epoch 54/200
## 3/3 - 0s - loss: 2535650.0000 - mean_absolute_error: 838.7695 - val_loss: 92304952.0000 - val_mean_absolute_error: 8181.4019 - 41ms/epoch - 14ms/step
## Epoch 55/200
## 3/3 - 0s - loss: 153583280.0000 - mean_absolute_error: 8465.3076 - val_loss: 32719.2148 - val_mean_absolute_error: 134.6143 - 35ms/epoch - 12ms/step
## Epoch 56/200
## 3/3 - 0s - loss: 101749.2500 - mean_absolute_error: 224.0085 - val_loss: 134643.0781 - val_mean_absolute_error: 326.8233 - 30ms/epoch - 10ms/step
## Epoch 57/200
## 3/3 - 0s - loss: 1137703.5000 - mean_absolute_error: 743.5020 - val_loss: 5621.4644 - val_mean_absolute_error: 56.5170 - 37ms/epoch - 12ms/step
## Epoch 58/200
## 3/3 - 0s - loss: 31622.2832 - mean_absolute_error: 92.0486 - val_loss: 416698.0000 - val_mean_absolute_error: 538.0172 - 36ms/epoch - 12ms/step
## Epoch 59/200
## 3/3 - 0s - loss: 2595974.2500 - mean_absolute_error: 1108.1090 - val_loss: 60489.6055 - val_mean_absolute_error: 219.5775 - 36ms/epoch - 12ms/step
## Epoch 60/200
## 3/3 - 0s - loss: 2171765.7500 - mean_absolute_error: 932.6613 - val_loss: 4797.9722 - val_mean_absolute_error: 52.7236 - 39ms/epoch - 13ms/step
## Epoch 61/200
## 3/3 - 0s - loss: 12954.5361 - mean_absolute_error: 75.0471 - val_loss: 203219.3594 - val_mean_absolute_error: 371.1813 - 39ms/epoch - 13ms/step
## Epoch 62/200
## 3/3 - 0s - loss: 151501.2969 - mean_absolute_error: 210.9671 - val_loss: 6151448.5000 - val_mean_absolute_error: 2103.2278 - 36ms/epoch - 12ms/step
## Epoch 63/200
## 3/3 - 0s - loss: 52609704.0000 - mean_absolute_error: 4104.1826 - val_loss: 60516196.0000 - val_mean_absolute_error: 6645.6826 - 39ms/epoch - 13ms/step
## Epoch 64/200
## 3/3 - 0s - loss: 118148256.0000 - mean_absolute_error: 5852.3223 - val_loss: 153369.1719 - val_mean_absolute_error: 324.0020 - 39ms/epoch - 13ms/step
## Epoch 65/200
## 3/3 - 0s - loss: 212744.6094 - mean_absolute_error: 284.8939 - val_loss: 1139.6920 - val_mean_absolute_error: 28.2405 - 35ms/epoch - 12ms/step
## Epoch 66/200
## 3/3 - 0s - loss: 18994.6133 - mean_absolute_error: 74.8744 - val_loss: 12831.0654 - val_mean_absolute_error: 84.6988 - 31ms/epoch - 10ms/step
## Epoch 67/200
## 3/3 - 0s - loss: 4378.7480 - mean_absolute_error: 37.4891 - val_loss: 898.2761 - val_mean_absolute_error: 22.0822 - 31ms/epoch - 10ms/step
## Epoch 68/200
## 3/3 - 0s - loss: 62395.3203 - mean_absolute_error: 143.6369 - val_loss: 3025.7419 - val_mean_absolute_error: 41.3140 - 34ms/epoch - 11ms/step
## Epoch 69/200
## 3/3 - 0s - loss: 23052.8027 - mean_absolute_error: 85.8221 - val_loss: 10799.8389 - val_mean_absolute_error: 95.0174 - 30ms/epoch - 10ms/step
## Epoch 70/200
## 3/3 - 0s - loss: 192267.9688 - mean_absolute_error: 261.1484 - val_loss: 1507.6796 - val_mean_absolute_error: 33.5813 - 31ms/epoch - 10ms/step
## Epoch 71/200
## 3/3 - 0s - loss: 186815.0781 - mean_absolute_error: 276.8752 - val_loss: 394451.9375 - val_mean_absolute_error: 525.7282 - 33ms/epoch - 11ms/step
## Epoch 72/200
## 3/3 - 0s - loss: 132560.9219 - mean_absolute_error: 175.5919 - val_loss: 39011.6055 - val_mean_absolute_error: 176.7799 - 32ms/epoch - 11ms/step
## Epoch 73/200
## 3/3 - 0s - loss: 39941.3281 - mean_absolute_error: 108.4675 - val_loss: 2248049.0000 - val_mean_absolute_error: 1263.6619 - 30ms/epoch - 10ms/step
## Epoch 74/200
## 3/3 - 0s - loss: 22744330.0000 - mean_absolute_error: 3119.4399 - val_loss: 19607658.0000 - val_mean_absolute_error: 3777.7512 - 36ms/epoch - 12ms/step
## Epoch 75/200
## 3/3 - 0s - loss: 87996248.0000 - mean_absolute_error: 5791.5669 - val_loss: 2148677.5000 - val_mean_absolute_error: 1241.1272 - 34ms/epoch - 11ms/step
## Epoch 76/200
## 3/3 - 0s - loss: 2065985.3750 - mean_absolute_error: 926.6773 - val_loss: 12976.6318 - val_mean_absolute_error: 105.3063 - 33ms/epoch - 11ms/step
## Epoch 77/200
## 3/3 - 0s - loss: 14628.1699 - mean_absolute_error: 88.9185 - val_loss: 32803.6523 - val_mean_absolute_error: 142.9313 - 31ms/epoch - 10ms/step
## Epoch 78/200
## 3/3 - 0s - loss: 40334.0781 - mean_absolute_error: 130.8246 - val_loss: 179863.0781 - val_mean_absolute_error: 371.7028 - 31ms/epoch - 10ms/step
## Epoch 79/200
## 3/3 - 0s - loss: 3286009.0000 - mean_absolute_error: 1184.0558 - val_loss: 1872694.3750 - val_mean_absolute_error: 1179.7385 - 32ms/epoch - 11ms/step
## Epoch 80/200
## 3/3 - 0s - loss: 14118857.0000 - mean_absolute_error: 2438.2290 - val_loss: 480.7404 - val_mean_absolute_error: 18.7436 - 32ms/epoch - 11ms/step
## Epoch 81/200
## 3/3 - 0s - loss: 1286.7532 - mean_absolute_error: 21.5733 - val_loss: 23828.2285 - val_mean_absolute_error: 122.2422 - 31ms/epoch - 10ms/step
## Epoch 82/200
## 3/3 - 0s - loss: 24384.1387 - mean_absolute_error: 97.0255 - val_loss: 841.4240 - val_mean_absolute_error: 23.9780 - 31ms/epoch - 10ms/step
## Epoch 83/200
## 3/3 - 0s - loss: 51228.5977 - mean_absolute_error: 118.8286 - val_loss: 22113.5039 - val_mean_absolute_error: 117.0461 - 31ms/epoch - 10ms/step
## Epoch 84/200
## 3/3 - 0s - loss: 607789.7500 - mean_absolute_error: 446.8210 - val_loss: 1142107.0000 - val_mean_absolute_error: 902.0145 - 32ms/epoch - 11ms/step
## Epoch 85/200
## 3/3 - 0s - loss: 25698574.0000 - mean_absolute_error: 3298.6118 - val_loss: 3318258.2500 - val_mean_absolute_error: 1546.1804 - 30ms/epoch - 10ms/step
## Epoch 86/200
## 3/3 - 0s - loss: 17519016.0000 - mean_absolute_error: 2747.0378 - val_loss: 292000.4375 - val_mean_absolute_error: 467.0139 - 34ms/epoch - 11ms/step
## Epoch 87/200
## 3/3 - 0s - loss: 179039.9375 - mean_absolute_error: 260.4535 - val_loss: 3622.5886 - val_mean_absolute_error: 44.2102 - 31ms/epoch - 10ms/step
## Epoch 88/200
## 3/3 - 0s - loss: 6280.5166 - mean_absolute_error: 37.8223 - val_loss: 71546.2969 - val_mean_absolute_error: 221.4520 - 33ms/epoch - 11ms/step
## Epoch 89/200
## 3/3 - 0s - loss: 436953.2812 - mean_absolute_error: 452.0117 - val_loss: 29754.5488 - val_mean_absolute_error: 153.0072 - 30ms/epoch - 10ms/step
## Epoch 90/200
## 3/3 - 0s - loss: 718396.6250 - mean_absolute_error: 572.8501 - val_loss: 256462.1406 - val_mean_absolute_error: 437.9351 - 31ms/epoch - 10ms/step
## Epoch 91/200
## 3/3 - 0s - loss: 318797.8125 - mean_absolute_error: 343.6335 - val_loss: 5266094.5000 - val_mean_absolute_error: 1951.3035 - 31ms/epoch - 10ms/step
## Epoch 92/200
## 3/3 - 0s - loss: 30799170.0000 - mean_absolute_error: 3511.8857 - val_loss: 67986.2500 - val_mean_absolute_error: 227.7242 - 30ms/epoch - 10ms/step
## Epoch 93/200
## 3/3 - 0s - loss: 42616.9375 - mean_absolute_error: 137.6563 - val_loss: 23887.8477 - val_mean_absolute_error: 136.8813 - 31ms/epoch - 10ms/step
## Epoch 94/200
## 3/3 - 0s - loss: 24457.0449 - mean_absolute_error: 109.4135 - val_loss: 18395.9141 - val_mean_absolute_error: 109.7767 - 30ms/epoch - 10ms/step
## Epoch 95/200
## 3/3 - 0s - loss: 238750.5938 - mean_absolute_error: 320.8312 - val_loss: 13683.7412 - val_mean_absolute_error: 93.7336 - 33ms/epoch - 11ms/step
## Epoch 96/200
## 3/3 - 0s - loss: 228422.9062 - mean_absolute_error: 307.6782 - val_loss: 1174.6683 - val_mean_absolute_error: 33.0905 - 30ms/epoch - 10ms/step
## Epoch 97/200
## 3/3 - 0s - loss: 588.8242 - mean_absolute_error: 18.8131 - val_loss: 5311.0234 - val_mean_absolute_error: 55.8430 - 31ms/epoch - 10ms/step
## Epoch 98/200
## 3/3 - 0s - loss: 189273.0000 - mean_absolute_error: 278.3611 - val_loss: 319843.7500 - val_mean_absolute_error: 476.3712 - 29ms/epoch - 10ms/step
## Epoch 99/200
## 3/3 - 0s - loss: 1318100.7500 - mean_absolute_error: 624.2157 - val_loss: 31753366.0000 - val_mean_absolute_error: 4809.6646 - 29ms/epoch - 10ms/step
## Epoch 100/200
## 3/3 - 0s - loss: 30514254.0000 - mean_absolute_error: 3658.4478 - val_loss: 84541.1094 - val_mean_absolute_error: 250.1991 - 35ms/epoch - 12ms/step
## Epoch 101/200
## 3/3 - 0s - loss: 250885.2188 - mean_absolute_error: 311.9144 - val_loss: 6944141.5000 - val_mean_absolute_error: 2252.7327 - 29ms/epoch - 10ms/step
## Epoch 102/200
## 3/3 - 0s - loss: 19719768.0000 - mean_absolute_error: 3020.4045 - val_loss: 2167203.2500 - val_mean_absolute_error: 1256.7057 - 30ms/epoch - 10ms/step
## Epoch 103/200
## 3/3 - 0s - loss: 3201344.5000 - mean_absolute_error: 1230.4966 - val_loss: 227475.4688 - val_mean_absolute_error: 404.6143 - 29ms/epoch - 10ms/step
## Epoch 104/200
## 3/3 - 0s - loss: 126957.7109 - mean_absolute_error: 228.2343 - val_loss: 304632.5625 - val_mean_absolute_error: 468.4704 - 34ms/epoch - 11ms/step
## Epoch 105/200
## 3/3 - 0s - loss: 468827.9375 - mean_absolute_error: 446.0646 - val_loss: 164239.7188 - val_mean_absolute_error: 347.5965 - 32ms/epoch - 11ms/step
## Epoch 106/200
## 3/3 - 0s - loss: 404883.7188 - mean_absolute_error: 412.0231 - val_loss: 249431.7500 - val_mean_absolute_error: 428.1185 - 39ms/epoch - 13ms/step
## Epoch 107/200
## 3/3 - 0s - loss: 852685.1250 - mean_absolute_error: 631.8765 - val_loss: 1696.3633 - val_mean_absolute_error: 33.6871 - 33ms/epoch - 11ms/step
## Epoch 108/200
## 3/3 - 0s - loss: 30260.6328 - mean_absolute_error: 88.6826 - val_loss: 68908.2969 - val_mean_absolute_error: 224.5719 - 30ms/epoch - 10ms/step
## Epoch 109/200
## 3/3 - 0s - loss: 2290737.5000 - mean_absolute_error: 1006.7604 - val_loss: 409285.0000 - val_mean_absolute_error: 547.3334 - 32ms/epoch - 11ms/step
## Epoch 110/200
## 3/3 - 0s - loss: 1412532.0000 - mean_absolute_error: 878.8727 - val_loss: 6883019.5000 - val_mean_absolute_error: 2234.8013 - 31ms/epoch - 10ms/step
## Epoch 111/200
## 3/3 - 0s - loss: 25474850.0000 - mean_absolute_error: 3365.6807 - val_loss: 990808.5625 - val_mean_absolute_error: 848.2744 - 35ms/epoch - 12ms/step
## Epoch 112/200
## 3/3 - 0s - loss: 2909528.0000 - mean_absolute_error: 1077.2312 - val_loss: 22265.8613 - val_mean_absolute_error: 124.6115 - 33ms/epoch - 11ms/step
## Epoch 113/200
## 3/3 - 0s - loss: 196962.7656 - mean_absolute_error: 299.8185 - val_loss: 722.6653 - val_mean_absolute_error: 16.8054 - 36ms/epoch - 12ms/step
## Epoch 114/200
## 3/3 - 0s - loss: 38334.3281 - mean_absolute_error: 112.4606 - val_loss: 3395.5334 - val_mean_absolute_error: 47.4295 - 32ms/epoch - 11ms/step
## Epoch 115/200
## 3/3 - 0s - loss: 31961.6758 - mean_absolute_error: 106.6673 - val_loss: 25176.9102 - val_mean_absolute_error: 133.5696 - 35ms/epoch - 12ms/step
## Epoch 116/200
## 3/3 - 0s - loss: 361894.3125 - mean_absolute_error: 382.7624 - val_loss: 3150.1011 - val_mean_absolute_error: 45.7732 - 32ms/epoch - 11ms/step
## Epoch 117/200
## 3/3 - 0s - loss: 11032.8213 - mean_absolute_error: 46.2494 - val_loss: 286396.5625 - val_mean_absolute_error: 455.3769 - 35ms/epoch - 12ms/step
## Epoch 118/200
## 3/3 - 0s - loss: 4074568.0000 - mean_absolute_error: 1239.4202 - val_loss: 87950.2578 - val_mean_absolute_error: 252.1011 - 36ms/epoch - 12ms/step
## Epoch 119/200
## 3/3 - 0s - loss: 343228.4375 - mean_absolute_error: 354.1418 - val_loss: 385445.0625 - val_mean_absolute_error: 533.0853 - 35ms/epoch - 12ms/step
## Epoch 120/200
## 3/3 - 0s - loss: 773345.0625 - mean_absolute_error: 587.4399 - val_loss: 68679.8750 - val_mean_absolute_error: 218.9030 - 32ms/epoch - 11ms/step
## Epoch 121/200
## 3/3 - 0s - loss: 956748.0000 - mean_absolute_error: 672.0156 - val_loss: 451270.4375 - val_mean_absolute_error: 577.6586 - 33ms/epoch - 11ms/step
## Epoch 122/200
## 3/3 - 0s - loss: 4663699.0000 - mean_absolute_error: 1507.8859 - val_loss: 1419493.6250 - val_mean_absolute_error: 1014.4671 - 30ms/epoch - 10ms/step
## Epoch 123/200
## 3/3 - 0s - loss: 5992500.5000 - mean_absolute_error: 1631.0907 - val_loss: 430.0992 - val_mean_absolute_error: 14.4033 - 31ms/epoch - 10ms/step
## Epoch 124/200
## 3/3 - 0s - loss: 17791.3828 - mean_absolute_error: 83.1656 - val_loss: 726.9236 - val_mean_absolute_error: 24.7913 - 31ms/epoch - 10ms/step
## Epoch 125/200
## 3/3 - 0s - loss: 5056.4185 - mean_absolute_error: 33.7889 - val_loss: 91524.8203 - val_mean_absolute_error: 258.2952 - 30ms/epoch - 10ms/step
## Epoch 126/200
## 3/3 - 0s - loss: 354237.5000 - mean_absolute_error: 366.8130 - val_loss: 467.4528 - val_mean_absolute_error: 15.9974 - 33ms/epoch - 11ms/step
## Epoch 127/200
## 3/3 - 0s - loss: 5993.3979 - mean_absolute_error: 28.4473 - val_loss: 166245.8750 - val_mean_absolute_error: 349.1440 - 30ms/epoch - 10ms/step
## Epoch 128/200
## 3/3 - 0s - loss: 1577939.1250 - mean_absolute_error: 868.7360 - val_loss: 982245.6250 - val_mean_absolute_error: 841.0688 - 32ms/epoch - 11ms/step
## Epoch 129/200
## 3/3 - 0s - loss: 897370.7500 - mean_absolute_error: 655.7756 - val_loss: 229564.4688 - val_mean_absolute_error: 403.7808 - 30ms/epoch - 10ms/step
## Epoch 130/200
## 3/3 - 0s - loss: 237460.9062 - mean_absolute_error: 341.0848 - val_loss: 599003.4375 - val_mean_absolute_error: 655.3722 - 31ms/epoch - 10ms/step
## Epoch 131/200
## 3/3 - 0s - loss: 5432268.5000 - mean_absolute_error: 1429.5894 - val_loss: 137.6288 - val_mean_absolute_error: 7.0437 - 32ms/epoch - 11ms/step
## Epoch 132/200
## 3/3 - 0s - loss: 175554.2344 - mean_absolute_error: 255.8757 - val_loss: 376282.5938 - val_mean_absolute_error: 525.5138 - 30ms/epoch - 10ms/step
## Epoch 133/200
## 3/3 - 0s - loss: 1446243.7500 - mean_absolute_error: 799.0149 - val_loss: 425756.7188 - val_mean_absolute_error: 560.5094 - 32ms/epoch - 11ms/step
## Epoch 134/200
## 3/3 - 0s - loss: 320884.7188 - mean_absolute_error: 403.3936 - val_loss: 13851.1074 - val_mean_absolute_error: 97.5690 - 31ms/epoch - 10ms/step
## Epoch 135/200
## 3/3 - 0s - loss: 26910.2676 - mean_absolute_error: 106.1822 - val_loss: 116.1909 - val_mean_absolute_error: 7.6214 - 31ms/epoch - 10ms/step
## Epoch 136/200
## 3/3 - 0s - loss: 9848.1035 - mean_absolute_error: 61.9568 - val_loss: 187.4280 - val_mean_absolute_error: 11.3721 - 47ms/epoch - 16ms/step
## Epoch 137/200
## 3/3 - 0s - loss: 158293.2344 - mean_absolute_error: 186.6511 - val_loss: 1304788.8750 - val_mean_absolute_error: 973.9418 - 32ms/epoch - 11ms/step
## Epoch 138/200
## 3/3 - 0s - loss: 1808835.7500 - mean_absolute_error: 902.4971 - val_loss: 559732.0625 - val_mean_absolute_error: 643.6171 - 30ms/epoch - 10ms/step
## Epoch 139/200
## 3/3 - 0s - loss: 354059.7188 - mean_absolute_error: 345.5819 - val_loss: 1455.9567 - val_mean_absolute_error: 32.4863 - 29ms/epoch - 10ms/step
## Epoch 140/200
## 3/3 - 0s - loss: 1252.9559 - mean_absolute_error: 23.3614 - val_loss: 91.0481 - val_mean_absolute_error: 8.6962 - 32ms/epoch - 11ms/step
## Epoch 141/200
## 3/3 - 0s - loss: 509.8655 - mean_absolute_error: 14.8420 - val_loss: 42.7837 - val_mean_absolute_error: 6.0504 - 37ms/epoch - 12ms/step
## Epoch 142/200
## 3/3 - 0s - loss: 1818.5392 - mean_absolute_error: 22.9417 - val_loss: 4183.9619 - val_mean_absolute_error: 55.2590 - 33ms/epoch - 11ms/step
## Epoch 143/200
## 3/3 - 0s - loss: 9175.0332 - mean_absolute_error: 70.0252 - val_loss: 390.5165 - val_mean_absolute_error: 16.2214 - 33ms/epoch - 11ms/step
## Epoch 144/200
## 3/3 - 0s - loss: 833.9427 - mean_absolute_error: 20.5952 - val_loss: 52.4098 - val_mean_absolute_error: 6.6154 - 34ms/epoch - 11ms/step
## Epoch 145/200
## 3/3 - 0s - loss: 5005.1592 - mean_absolute_error: 31.6098 - val_loss: 19844.3848 - val_mean_absolute_error: 120.1162 - 33ms/epoch - 11ms/step
## Epoch 146/200
## 3/3 - 0s - loss: 42157.9297 - mean_absolute_error: 130.1169 - val_loss: 576586.0000 - val_mean_absolute_error: 647.2552 - 36ms/epoch - 12ms/step
## Epoch 147/200
## 3/3 - 0s - loss: 1487011.7500 - mean_absolute_error: 745.7644 - val_loss: 93213.2500 - val_mean_absolute_error: 260.9442 - 32ms/epoch - 11ms/step
## Epoch 148/200
## 3/3 - 0s - loss: 61970.7656 - mean_absolute_error: 116.1210 - val_loss: 1334.2076 - val_mean_absolute_error: 29.7729 - 33ms/epoch - 11ms/step
## Epoch 149/200
## 3/3 - 0s - loss: 3055.3159 - mean_absolute_error: 36.7403 - val_loss: 6.8815 - val_mean_absolute_error: 2.4467 - 34ms/epoch - 11ms/step
## Epoch 150/200
## 3/3 - 0s - loss: 35.9567 - mean_absolute_error: 3.0652 - val_loss: 379.9266 - val_mean_absolute_error: 15.2643 - 35ms/epoch - 12ms/step
## Epoch 151/200
## 3/3 - 0s - loss: 1782.0988 - mean_absolute_error: 27.5929 - val_loss: 11.5144 - val_mean_absolute_error: 2.8848 - 33ms/epoch - 11ms/step
## Epoch 152/200
## 3/3 - 0s - loss: 23.7713 - mean_absolute_error: 2.9590 - val_loss: 6.7567 - val_mean_absolute_error: 2.4171 - 33ms/epoch - 11ms/step
## Epoch 153/200
## 3/3 - 0s - loss: 234.0905 - mean_absolute_error: 9.8754 - val_loss: 102.8718 - val_mean_absolute_error: 7.3147 - 30ms/epoch - 10ms/step
## Epoch 154/200
## 3/3 - 0s - loss: 531.2399 - mean_absolute_error: 16.2913 - val_loss: 1470.8722 - val_mean_absolute_error: 33.6668 - 30ms/epoch - 10ms/step
## Epoch 155/200
## 3/3 - 0s - loss: 12457.1816 - mean_absolute_error: 61.9637 - val_loss: 211658.5781 - val_mean_absolute_error: 393.5392 - 31ms/epoch - 10ms/step
## Epoch 156/200
## 3/3 - 0s - loss: 1907337.3750 - mean_absolute_error: 925.4459 - val_loss: 389233.0000 - val_mean_absolute_error: 530.5489 - 30ms/epoch - 10ms/step
## Epoch 157/200
## 3/3 - 0s - loss: 221213.4844 - mean_absolute_error: 227.9880 - val_loss: 10.1331 - val_mean_absolute_error: 2.2994 - 32ms/epoch - 11ms/step
## Epoch 158/200
## 3/3 - 0s - loss: 1381.6731 - mean_absolute_error: 20.9561 - val_loss: 275.7252 - val_mean_absolute_error: 14.4810 - 33ms/epoch - 11ms/step
## Epoch 159/200
## 3/3 - 0s - loss: 1689.3594 - mean_absolute_error: 28.9907 - val_loss: 1448.9662 - val_mean_absolute_error: 31.7639 - 32ms/epoch - 11ms/step
## Epoch 160/200
## 3/3 - 0s - loss: 18160.9473 - mean_absolute_error: 91.7535 - val_loss: 14302.5801 - val_mean_absolute_error: 102.2473 - 31ms/epoch - 10ms/step
## Epoch 161/200
## 3/3 - 0s - loss: 229392.5000 - mean_absolute_error: 283.3499 - val_loss: 47442.5547 - val_mean_absolute_error: 184.7773 - 33ms/epoch - 11ms/step
## Epoch 162/200
## 3/3 - 0s - loss: 597554.3750 - mean_absolute_error: 509.5569 - val_loss: 177605.3594 - val_mean_absolute_error: 357.1114 - 32ms/epoch - 11ms/step
## Epoch 163/200
## 3/3 - 0s - loss: 441195.3438 - mean_absolute_error: 429.1706 - val_loss: 180504.2656 - val_mean_absolute_error: 364.0522 - 32ms/epoch - 11ms/step
## Epoch 164/200
## 3/3 - 0s - loss: 963219.8750 - mean_absolute_error: 605.6765 - val_loss: 504390.6562 - val_mean_absolute_error: 604.7960 - 32ms/epoch - 11ms/step
## Epoch 165/200
## 3/3 - 0s - loss: 1264254.6250 - mean_absolute_error: 674.0167 - val_loss: 111704.1875 - val_mean_absolute_error: 284.3521 - 30ms/epoch - 10ms/step
## Epoch 166/200
## 3/3 - 0s - loss: 203982.6719 - mean_absolute_error: 219.2178 - val_loss: 5497.4868 - val_mean_absolute_error: 63.5537 - 33ms/epoch - 11ms/step
## Epoch 167/200
## 3/3 - 0s - loss: 10977.6191 - mean_absolute_error: 53.7335 - val_loss: 8.0237 - val_mean_absolute_error: 2.2376 - 32ms/epoch - 11ms/step
## Epoch 168/200
## 3/3 - 0s - loss: 8.8539 - mean_absolute_error: 1.7761 - val_loss: 1.9708 - val_mean_absolute_error: 1.1796 - 31ms/epoch - 10ms/step
## Epoch 169/200
## 3/3 - 0s - loss: 7.0644 - mean_absolute_error: 1.9370 - val_loss: 2.9183 - val_mean_absolute_error: 1.2682 - 31ms/epoch - 10ms/step
## Epoch 170/200
## 3/3 - 0s - loss: 8.0536 - mean_absolute_error: 1.7811 - val_loss: 2.0759 - val_mean_absolute_error: 1.2561 - 30ms/epoch - 10ms/step
## Epoch 171/200
## 3/3 - 0s - loss: 2.9380 - mean_absolute_error: 1.4808 - val_loss: 2.1530 - val_mean_absolute_error: 1.2868 - 32ms/epoch - 11ms/step
## Epoch 172/200
## 3/3 - 0s - loss: 2.5209 - mean_absolute_error: 1.3521 - val_loss: 2.0810 - val_mean_absolute_error: 1.2635 - 31ms/epoch - 10ms/step
## Epoch 173/200
## 3/3 - 0s - loss: 2.5089 - mean_absolute_error: 1.3470 - val_loss: 2.9268 - val_mean_absolute_error: 1.3856 - 32ms/epoch - 11ms/step
## Epoch 174/200
## 3/3 - 0s - loss: 3.8935 - mean_absolute_error: 1.5365 - val_loss: 2.5146 - val_mean_absolute_error: 1.3422 - 30ms/epoch - 10ms/step
## Epoch 175/200
## 3/3 - 0s - loss: 3.3663 - mean_absolute_error: 1.4674 - val_loss: 7.1548 - val_mean_absolute_error: 2.3942 - 31ms/epoch - 10ms/step
## Epoch 176/200
## 3/3 - 0s - loss: 77.6947 - mean_absolute_error: 6.1288 - val_loss: 70.7006 - val_mean_absolute_error: 6.5913 - 39ms/epoch - 13ms/step
## Epoch 177/200
## 3/3 - 0s - loss: 1018.6829 - mean_absolute_error: 19.9167 - val_loss: 60.3742 - val_mean_absolute_error: 6.9452 - 41ms/epoch - 14ms/step
## Epoch 178/200
## 3/3 - 0s - loss: 20.1124 - mean_absolute_error: 2.9692 - val_loss: 58.3268 - val_mean_absolute_error: 6.8251 - 36ms/epoch - 12ms/step
## Epoch 179/200
## 3/3 - 0s - loss: 2751.5171 - mean_absolute_error: 31.1079 - val_loss: 281.9444 - val_mean_absolute_error: 14.5143 - 38ms/epoch - 13ms/step
## Epoch 180/200
## 3/3 - 0s - loss: 45005.2148 - mean_absolute_error: 129.5238 - val_loss: 56452.9102 - val_mean_absolute_error: 201.4269 - 44ms/epoch - 15ms/step
## Epoch 181/200
## 3/3 - 0s - loss: 784861.6875 - mean_absolute_error: 561.1381 - val_loss: 258325.4688 - val_mean_absolute_error: 434.5649 - 42ms/epoch - 14ms/step
## Epoch 182/200
## 3/3 - 0s - loss: 1097230.2500 - mean_absolute_error: 687.8195 - val_loss: 17444.1426 - val_mean_absolute_error: 111.0428 - 42ms/epoch - 14ms/step
## Epoch 183/200
## 3/3 - 0s - loss: 20123.1934 - mean_absolute_error: 102.1124 - val_loss: 2204.1211 - val_mean_absolute_error: 41.0962 - 37ms/epoch - 12ms/step
## Epoch 184/200
## 3/3 - 0s - loss: 3876.1802 - mean_absolute_error: 42.1887 - val_loss: 17462.6387 - val_mean_absolute_error: 116.1515 - 34ms/epoch - 11ms/step
## Epoch 185/200
## 3/3 - 0s - loss: 12884.5850 - mean_absolute_error: 71.4345 - val_loss: 8.2776 - val_mean_absolute_error: 2.5920 - 31ms/epoch - 10ms/step
## Epoch 186/200
## 3/3 - 0s - loss: 40.9125 - mean_absolute_error: 3.8916 - val_loss: 7.5745 - val_mean_absolute_error: 2.5653 - 32ms/epoch - 11ms/step
## Epoch 187/200
## 3/3 - 0s - loss: 28.7258 - mean_absolute_error: 3.4560 - val_loss: 15.9163 - val_mean_absolute_error: 3.7351 - 32ms/epoch - 11ms/step
## Epoch 188/200
## 3/3 - 0s - loss: 34.8812 - mean_absolute_error: 4.3637 - val_loss: 7.6218 - val_mean_absolute_error: 2.5566 - 30ms/epoch - 10ms/step
## Epoch 189/200
## 3/3 - 0s - loss: 26.7710 - mean_absolute_error: 3.2154 - val_loss: 196.7604 - val_mean_absolute_error: 10.4975 - 32ms/epoch - 11ms/step
## Epoch 190/200
## 3/3 - 0s - loss: 297.3047 - mean_absolute_error: 11.0045 - val_loss: 9.2098 - val_mean_absolute_error: 2.6340 - 30ms/epoch - 10ms/step
## Epoch 191/200
## 3/3 - 0s - loss: 11.2494 - mean_absolute_error: 2.5024 - val_loss: 92.5074 - val_mean_absolute_error: 6.8252 - 31ms/epoch - 10ms/step
## Epoch 192/200
## 3/3 - 0s - loss: 83.7759 - mean_absolute_error: 5.3374 - val_loss: 979.8765 - val_mean_absolute_error: 25.2076 - 32ms/epoch - 11ms/step
## Epoch 193/200
## 3/3 - 0s - loss: 8723.0312 - mean_absolute_error: 65.7257 - val_loss: 190.7277 - val_mean_absolute_error: 12.7622 - 32ms/epoch - 11ms/step
## Epoch 194/200
## 3/3 - 0s - loss: 1859.9871 - mean_absolute_error: 25.9576 - val_loss: 14.9970 - val_mean_absolute_error: 3.2210 - 30ms/epoch - 10ms/step
## Epoch 195/200
## 3/3 - 0s - loss: 186.5272 - mean_absolute_error: 6.5522 - val_loss: 874.4487 - val_mean_absolute_error: 26.2726 - 32ms/epoch - 11ms/step
## Epoch 196/200
## 3/3 - 0s - loss: 11187.2139 - mean_absolute_error: 60.0308 - val_loss: 138.4275 - val_mean_absolute_error: 8.6617 - 34ms/epoch - 11ms/step
## Epoch 197/200
## 3/3 - 0s - loss: 2969.7551 - mean_absolute_error: 36.7896 - val_loss: 1227.1619 - val_mean_absolute_error: 30.9068 - 32ms/epoch - 11ms/step
## Epoch 198/200
## 3/3 - 0s - loss: 5218.3994 - mean_absolute_error: 42.2745 - val_loss: 29165.1426 - val_mean_absolute_error: 148.7445 - 35ms/epoch - 12ms/step
## Epoch 199/200
## 3/3 - 0s - loss: 246666.6406 - mean_absolute_error: 295.3921 - val_loss: 46679.2266 - val_mean_absolute_error: 183.0828 - 32ms/epoch - 11ms/step
## Epoch 200/200
## 3/3 - 0s - loss: 176515.6719 - mean_absolute_error: 269.3751 - val_loss: 1371.3434 - val_mean_absolute_error: 30.9480 - 33ms/epoch - 11ms/step
# Evaluate the model on the test data
# evaluate() returns the compiled metrics: here the loss (MSE) and MAE.
score <- model %>% evaluate(test_features, test_response, verbose = 0)
score ## loss mean_absolute_error
## 3489.22021 39.01598
## 1/1 - 0s - 55ms/epoch - 55ms/step
# Calculate RMSE
# NOTE(review): `predictions` is assumed to come from predict(model, test_features)
# earlier in the document — confirm; RMSE is computed manually here rather than
# taken from evaluate()'s loss.
rmse <- round(sqrt(mean((predictions - test_response)^2)),3)
print(paste("RMSE on test data:", rmse))## [1] "RMSE on test data: 59.07"
# Calculate correlations
# Pearson correlation between predicted and actual test responses.
correlation <- round(cor(predictions, test_response),3)
print(paste("Correlation between actual and predicted values:", correlation))## [1] "Correlation between actual and predicted values: 0.452"
## Model: "sequential_2"
## ________________________________________________________________________________
## Layer (type) Output Shape Param #
## ================================================================================
## dense_7 (Dense) (None, 256) 2048
## dense_6 (Dense) (None, 128) 32896
## dense_5 (Dense) (None, 64) 8256
## dense_4 (Dense) (None, 1) 65
## ================================================================================
## Total params: 43,265
## Trainable params: 43,265
## Non-trainable params: 0
## ________________________________________________________________________________
#install.packages("plotly")
#install.packages("fastmap")
#library(fastmap)
library(plotly)
# The commented-out block below would plot the per-epoch training/validation
# loss and MAE curves from the keras `history` object; it is kept for reference.
#epochs <- 50
#time <- 1:epochs
#hist_df <- data.frame(time=time, loss=history$metrics$loss, mae=history$metrics$mean_absolute_error,
# valid_loss=history$metrics$val_loss, valid_mae=history$metrics$val_mean_absolute_error)
#plot_ly(hist_df, x = ~time) %>%
# add_trace(y = ~loss, name = 'training loss', type = "scatter", mode = 'lines') %>%
# add_trace(y = ~mae, name = 'training MAE', type = "scatter", mode = 'lines+markers') %>%
# add_trace(y = ~valid_loss, name = 'validation loss', type = "scatter", mode = 'lines+markers') %>%
# add_trace(y = ~valid_mae, name = 'validation MAE', type = "scatter", mode = 'lines+markers') %>%
# layout(title = "NN Model Performance",
# legend = list(orientation = 'h'),
# yaxis = list(title = "Metric"))
# Build a frame of actual vs. predicted test values for the scatter plot.
# NOTE(review): `cases=200` is a hard-coded scalar recycled to the number of
# test rows — it is never used in the plot; verify it matches the test-set size.
hist_df <- data.frame(cases=200, real=as.matrix(test_response), predicted=predictions)
# Scatter of actual vs. predicted, with a fitted least-squares line overlaid.
plot_ly(hist_df, x = ~real) %>%
  add_trace(y = ~predicted, name = 'Scatter (Actual vs. Predicted)', type="scatter", mode = 'markers') %>%
  add_lines(x = ~real, y = ~fitted(lm(predicted ~ real, hist_df)), name="LM(Pred ~ Real)") %>%
  layout(title=paste0("NN Model Prediction (correlation=", correlation,")"),
         legend = list(orientation = 'h'), yaxis=list(title="predicted"))## [1] "Corr(Actual, Predicted)=0.452"
## Model: "sequential_2"
## ________________________________________________________________________________
## Layer (type) Output Shape Param #
## ================================================================================
## dense_7 (Dense) (None, 256) 2048
## dense_6 (Dense) (None, 128) 32896
## dense_5 (Dense) (None, 64) 8256
## dense_4 (Dense) (None, 1) 65
## ================================================================================
## Total params: 43,265
## Trainable params: 43,265
## Non-trainable params: 0
## ________________________________________________________________________________
The training process was run for 200 epochs. An epoch is one complete pass through the entire training dataset.
The reported Root Mean Square Error (RMSE) on the test data is 59.07, which is relatively small considering the response values span a range of roughly 0 to 10,000; thus the model could be deemed reasonably accurate.
The correlation of 0.452 suggests a moderate positive linear relationship between the actual and predicted values. This indicates that the model has learned some of the underlying patterns in the data but there’s still a substantial amount of variance that the model is not capturing. A higher correlation would indicate a better fit of the model to the data.
The model “sequential_2” is a densely connected neural network with four layers. The first layer has 256 neurons, the second 128, the third 64, and the final output layer has a single neuron, indicating the model is likely designed for regression or binary classification tasks. The model is fairly complex, with a total of 43,265 parameters, all of which are trainable. This indicates a potentially high capacity for learning from data, but also a risk of overfitting if not enough training data is provided or if proper regularization techniques are not employed. Each layer’s parameters are derived from the connections to all neurons in the preceding layer, along with a bias term for each neuron. The summary confirms that no parameters are frozen or non-trainable, meaning the entire model will be updated during the training process.
#install.packages("magick")
#install.packages("dplyr")
#py_install("Pillow", envname = "r-reticulate")
library(keras)
library(dplyr)
library(magick)
# Classify an image with three pre-trained ImageNet models (ResNet50, VGG19,
# VGG16) and return each model's top-5 predictions together with a
# display-sized copy of the image.
#
# Args:
#   img_url: URL of the image to download and classify.
# Returns:
#   A list with two elements:
#     predictions: named list (resnet50, vgg19, vgg16) of top-5 decoded
#                  ImageNet predictions (class name, description, score).
#     image:       a magick image resized to fit 800x800 for display.
classify_image <- function(img_url) {
  # Make sure the download target directory exists before writing into it;
  # download.file() errors if the destination folder is missing.
  out_dir <- file.path(getwd(), "results")
  if (!dir.exists(out_dir)) {
    dir.create(out_dir, recursive = TRUE)
  }
  img_path <- file.path(out_dir, "image.png")
  # Binary mode ("wb") so the image bytes are not mangled on Windows
  download.file(img_url, img_path, mode = "wb")
  # Read the file once, then derive both the model input and the display copy.
  # "224x224!" forces the exact input size the ImageNet models expect.
  raw_img <- image_read(img_path)
  img <- image_resize(raw_img, "224x224!")
  img_for_display <- image_resize(raw_img, "800x800")
  # Preprocess for prediction: pixel array -> 4-D tensor with a batch
  # dimension of 1, then apply the standard ImageNet preprocessing.
  x <- as.integer(image_data(img))
  x <- array_reshape(x, c(1, dim(x)))
  x <- imagenet_preprocess_input(x)
  # Helper: run one model on x and decode its top-5 ImageNet predictions
  top5 <- function(model) {
    imagenet_decode_predictions(predict(model, x), top = 5)[[1]]
  }
  predictions_list <- list(
    resnet50 = top5(application_resnet50(weights = "imagenet")),
    vgg19 = top5(application_vgg19(weights = "imagenet")),
    vgg16 = top5(application_vgg16(weights = "imagenet"))
  )
  # Return both the predictions and the image for display
  list(predictions = predictions_list, image = img_for_display)
}
# Use the function to classify an image
# Three test images: a daisy photo, a volcano eruption, and a brain image.
# The trailing "## 1/1 ..." text fused onto each call is knitted Keras
# prediction-progress output, one line per model (ResNet50, VGG19, VGG16).
Daisy <- classify_image("https://fileinfo.com/img/ss/xl/jpeg_43-2.jpg")## 1/1 - 1s - 1s/epoch - 1s/step
## 1/1 - 0s - 475ms/epoch - 475ms/step
## 1/1 - 0s - 391ms/epoch - 391ms/step
Volcano <- classify_image("https://media-cldnry.s-nbcnews.com/image/upload/t_fit-1500w,f_auto,q_auto:best/rockcms/2023-05/230522-Mexico-volcano-Popocatepetl-eruption-lava-ac-902p-373199.jpg")## 1/1 - 1s - 1s/epoch - 1s/step
## 1/1 - 0s - 437ms/epoch - 437ms/step
## 1/1 - 0s - 413ms/epoch - 413ms/step
Brain <- classify_image("https://media.wired.com/photos/59324e5452d99d6b984dd9a0/master/pass/brain1.jpg")## 1/1 - 1s - 1s/epoch - 1s/step
## 1/1 - 0s - 477ms/epoch - 477ms/step
## 1/1 - 0s - 436ms/epoch - 436ms/step
## # A tibble: 1 × 7
## format width height colorspace matte filesize density
## <chr> <int> <int> <chr> <lgl> <int> <chr>
## 1 JPEG 800 554 sRGB FALSE 0 72x72
## $resnet50
## class_name class_description score
## 1 n11939491 daisy 0.75317287
## 2 n03457902 greenhouse 0.08539081
## 3 n03930313 picket_fence 0.03835879
## 4 n03782006 monitor 0.01346979
## 5 n04485082 tripod 0.01004972
##
## $vgg19
## class_name class_description score
## 1 n11939491 daisy 0.733606100
## 2 n03991062 pot 0.038622856
## 3 n03891251 park_bench 0.023287510
## 4 n03930313 picket_fence 0.020245489
## 5 n02280649 cabbage_butterfly 0.009900589
##
## $vgg16
## class_name class_description score
## 1 n11939491 daisy 0.52277225
## 2 n11879895 rapeseed 0.19854690
## 3 n03930313 picket_fence 0.04138704
## 4 n02280649 cabbage_butterfly 0.03947129
## 5 n02281406 sulphur_butterfly 0.02848973
ResNet50: Daisy (75.32%), greenhouse (8.54%), picket fence (3.84%), monitor (1.35%), tripod (1.00%).
VGG19: Daisy (73.36%), pot (3.86%), park bench (2.33%), picket fence (2.02%), cabbage butterfly (0.99%).
VGG16: Daisy (52.28%), rapeseed (19.85%), picket fence (4.14%), cabbage butterfly (3.95%), sulphur butterfly (2.85%).
## # A tibble: 1 × 7
## format width height colorspace matte filesize density
## <chr> <int> <int> <chr> <lgl> <int> <chr>
## 1 JPEG 800 533 sRGB FALSE 0 72x72
## $resnet50
## class_name class_description score
## 1 n09472597 volcano 1.000000e+00
## 2 n04330267 stove 3.136750e-10
## 3 n09288635 geyser 1.817625e-10
## 4 n03347037 fire_screen 1.810664e-10
## 5 n02939185 caldron 5.460157e-11
##
## $vgg19
## class_name class_description score
## 1 n09472597 volcano 9.999828e-01
## 2 n04456115 torch 1.468843e-05
## 3 n03729826 matchstick 1.558330e-06
## 4 n01443537 goldfish 3.081052e-07
## 5 n04330267 stove 2.524659e-07
##
## $vgg16
## class_name class_description score
## 1 n09472597 volcano 9.997590e-01
## 2 n04456115 torch 1.812786e-04
## 3 n01443537 goldfish 1.262610e-05
## 4 n04330267 stove 1.225619e-05
## 5 n01910747 jellyfish 9.439173e-06
ResNet50: Volcano (100%), stove (0.00000003137%), geyser (0.00000001818%), fire screen (0.00000001811%), caldron (0.00000000546%).
VGG19: Volcano (99.98%), torch (0.0015%), matchstick (0.00016%), goldfish (0.00003%), stove (0.000025%).
VGG16: Volcano (99.98%), torch (0.018%), goldfish (0.0013%), stove (0.0012%), jellyfish (0.00094%).
## # A tibble: 1 × 7
## format width height colorspace matte filesize density
## <chr> <int> <int> <chr> <lgl> <int> <chr>
## 1 JPEG 800 640 sRGB FALSE 0 72x72
## $resnet50
## class_name class_description score
## 1 n01917289 brain_coral 0.06373768
## 2 n03041632 cleaver 0.06237394
## 3 n03627232 knot 0.05447683
## 4 n01930112 nematode 0.05420838
## 5 n03720891 maraca 0.04702024
##
## $vgg19
## class_name class_description score
## 1 n13037406 gyromitra 0.66916829
## 2 n07720875 bell_pepper 0.04610845
## 3 n04599235 wool 0.03687157
## 4 n01917289 brain_coral 0.02477533
## 5 n07695742 pretzel 0.02373043
##
## $vgg16
## class_name class_description score
## 1 n01917289 brain_coral 0.21701127
## 2 n13037406 gyromitra 0.19762117
## 3 n04599235 wool 0.05079585
## 4 n12267677 acorn 0.02914893
## 5 n03840681 ocarina 0.02528968
ResNet50: Brain coral (6.37%), cleaver (6.24%), knot (5.45%), nematode (5.42%), maraca (4.70%).
VGG19: Gyromitra (66.92%), bell pepper (4.61%), wool (3.69%), brain coral (2.48%), pretzel (2.37%).
VGG16: Brain coral (21.70%), gyromitra (19.76%), wool (5.08%), acorn (2.91%), ocarina (2.53%).
The predictions reflect the confidence level of the models in recognizing the objects in the images, with the score representing the probability assigned to each label by the respective models.
For the Daisy image, all three models correctly identify the daisy with varying degrees of confidence, indicating that the image is likely a clear representation of a daisy and that the models have been well-trained to recognize this class.
For the Volcano image, the models are extremely confident that the image is of a volcano. This high confidence across all models suggests that the image has very distinctive features that are strongly associated with the concept of a volcano.
For the Brain image, the models do not predict a human brain but rather objects with a similar appearance, like brain coral and a type of fungus. This indicates a limitation in the models’ ability to correctly classify this image, which could be due to a lack of representative training data for human brains or the complexity of the image that does not match well with the patterns learned from the ImageNet database.
Based on the available data, ResNet50 seems to be the most reliable, having the highest confidence in the daisy and volcano images, and providing a thematically related prediction for the brain image. VGG19's top prediction for the brain image is less related than the other models' predictions, which might suggest it is the least accurate of the three models here.
These results show the strengths and limitations of pre-trained ImageNet models when applied to specific images. While they are generally good at recognizing a wide range of objects, their accuracy can vary depending on the similarity of the test images to the training data and the distinctiveness of the image features.